{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/lokwq/TextBrewer/blob/add_note_examples/msra.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3Kh5ekwgrCG4"
   },
   "source": [
    "This notebook shows how to fine-tune a model on the msra_ner dataset and how to distill the model with TextBrewer.\n",
    "\n",
    "Detailed documentation can be found here:\n",
    "https://github.com/airaria/TextBrewer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "xi_1qOral-LE"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "device='cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4cahHLZXmI81"
   },
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install datasets transformers seqeval textbrewer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "lfq_aPKhe9jJ"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from transformers import BertForSequenceClassification, BertTokenizer,BertConfig,BertForTokenClassification\n",
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
    "from transformers import Trainer, TrainingArguments\n",
    "from transformers import pipeline\n",
    "from datasets import load_dataset,load_metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NvnDEfRewRhJ"
   },
   "source": [
    "### Prepare dataset to train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "ZtoFe4FPmWh_"
   },
   "outputs": [],
   "source": [
    "task = \"ner\" #  \"ner\", \"pos\" or \"chunk\"\n",
    "model_checkpoint = \"bert-base-chinese\"\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OhvaZdRqmow6"
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset, load_metric\n",
    "\n",
    "# Fetch the MSRA NER corpus by its dataset name.\n",
    "datasets = load_dataset(path=\"msra_ner\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "WgM6aADvnIm5",
    "outputId": "500364c0-909a-47ae-c785-686dc2eccb12"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']"
      ]
     },
     "execution_count": 6,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tag names are read from the feature behind the \"<task>_tags\" column of the train split.\n",
    "label_list = (\n",
    "    datasets[\"train\"]\n",
    "    .features[f\"{task}_tags\"]\n",
    "    .feature\n",
    "    .names\n",
    ")\n",
    "label_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vO48GEqlnRuh"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Load the tokenizer belonging to the checkpoint named in `model_checkpoint`.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    model_checkpoint,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "6fO7iYQanY1G"
   },
   "outputs": [],
   "source": [
    "def tokenize_and_align_labels(examples, label_all_tokens=True):\n",
    "    \"\"\"Tokenize a batch of pre-split sentences and align the tag ids with the tokens.\n",
    "\n",
    "    Args:\n",
    "        examples: a batch as passed by `datasets.map(..., batched=True)`; reads\n",
    "            `examples[\"tokens\"]` and `examples[f\"{task}_tags\"]`.\n",
    "        label_all_tokens: if True, every token of a word receives the word's label;\n",
    "            if False, only the first token of a word is labelled and the remaining\n",
    "            ones get -100. The name was previously read as an undefined global, so\n",
    "            the function raised NameError whenever a word produced several tokens.\n",
    "\n",
    "    Returns:\n",
    "        The tokenizer output with an added \"labels\" entry holding one list of\n",
    "        label ids per example.\n",
    "    \"\"\"\n",
    "    tokenized_inputs = tokenizer(examples[\"tokens\"], truncation=True, is_split_into_words=True)\n",
    "\n",
    "    labels = []\n",
    "    for i, label in enumerate(examples[f\"{task}_tags\"]):\n",
    "        word_ids = tokenized_inputs.word_ids(batch_index=i)\n",
    "        previous_word_idx = None\n",
    "        label_ids = []\n",
    "        for word_idx in word_ids:\n",
    "            # Special tokens have a word id that is None. We set the label to -100 so they are automatically\n",
    "            # ignored in the loss function.\n",
    "            if word_idx is None:\n",
    "                label_ids.append(-100)\n",
    "            # We set the label for the first token of each word.\n",
    "            elif word_idx != previous_word_idx:\n",
    "                label_ids.append(label[word_idx])\n",
    "            # For the other tokens in a word, we set the label to either the current label or -100, depending on\n",
    "            # the label_all_tokens flag.\n",
    "            else:\n",
    "                label_ids.append(label[word_idx] if label_all_tokens else -100)\n",
    "            previous_word_idx = word_idx\n",
    "\n",
    "        labels.append(label_ids)\n",
    "\n",
    "    tokenized_inputs[\"labels\"] = labels\n",
    "    return tokenized_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "n9iuLfNDoqz6"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def compute_metrics(p):\n",
    "    \"\"\"Score token-level predictions with the global seqeval `metric`.\n",
    "\n",
    "    `p` unpacks into (predictions, labels). The arg-max over axis 2 of the\n",
    "    predictions gives one label id per position; positions whose reference\n",
    "    label is -100 are dropped before ids are mapped to names via `label_list`.\n",
    "    Returns the overall precision, recall, f1 and accuracy as a dict.\n",
    "    \"\"\"\n",
    "    raw_scores, gold_ids = p\n",
    "    pred_ids = np.argmax(raw_scores, axis=2)\n",
    "\n",
    "    true_predictions, true_labels = [], []\n",
    "    for pred_row, gold_row in zip(pred_ids, gold_ids):\n",
    "        # Keep only the positions that carry a real label (-100 marks ignored ones).\n",
    "        kept = [(pred, gold) for pred, gold in zip(pred_row, gold_row) if gold != -100]\n",
    "        true_predictions.append([label_list[pred] for pred, _ in kept])\n",
    "        true_labels.append([label_list[gold] for _, gold in kept])\n",
    "\n",
    "    scores = metric.compute(predictions=true_predictions, references=true_labels)\n",
    "    summary = {}\n",
    "    summary[\"precision\"] = scores[\"overall_precision\"]\n",
    "    summary[\"recall\"] = scores[\"overall_recall\"]\n",
    "    summary[\"f1\"] = scores[\"overall_f1\"]\n",
    "    summary[\"accuracy\"] = scores[\"overall_accuracy\"]\n",
    "    return summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "S9X0l225n7WZ"
   },
   "outputs": [],
   "source": [
    "from transformers import BertForTokenClassification, TrainingArguments, Trainer\n",
    "\n",
    "# Pretrained checkpoint with a token-classification head sized to the tag set.\n",
    "model = BertForTokenClassification.from_pretrained(\n",
    "    model_checkpoint,\n",
    "    num_labels=len(label_list),\n",
    ")\n",
    "# Move the model to the selected device; as the last expression its repr is displayed.\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "87HPRyytoYb8"
   },
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForTokenClassification\n",
    "\n",
    "# Collator handed to the Trainer to assemble batches from tokenized examples.\n",
    "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "97wY0abnobYe"
   },
   "outputs": [],
   "source": [
    "# seqeval scorer; compute_metrics reads it through the global name `metric`.\n",
    "metric = load_metric(path=\"seqeval\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xxl_04jup6ep"
   },
   "outputs": [],
   "source": [
    "# Run tokenization and label alignment over every split, one batch of examples at a time.\n",
    "tokenized_datasets = datasets.map(\n",
    "    function=tokenize_and_align_labels,\n",
    "    batched=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "ENBYA77ToLwV"
   },
   "outputs": [],
   "source": [
    "# Fine-tuning hyper-parameters; evaluation runs once per epoch.\n",
    "args = TrainingArguments(\n",
    "    output_dir=f\"test-{task}\",\n",
    "    num_train_epochs=2,\n",
    "    learning_rate=2e-5,\n",
    "    weight_decay=0.01,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    evaluation_strategy=\"epoch\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "Qaj3gN7xovJ0"
   },
   "outputs": [],
   "source": [
    "# Trainer for fine-tuning: trains on the train split and evaluates on the test split.\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"test\"],\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NzxXm1K1o_Uh"
   },
   "outputs": [],
   "source": [
    "# Fine-tune the model; the returned training summary is shown as the cell output.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wRiDIS5WuJ7s"
   },
   "outputs": [],
   "source": [
    "# Save the teacher model weights (state_dict only) so the distillation section can reload them.\n",
    "torch.save(\n",
    "    model.state_dict(),\n",
    "    '/content/drive/MyDrive/msra_teacher_model.pt',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "e5fK_phaslqc"
   },
   "source": [
    "### Start distillation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Z03IrKzzXYMZ"
   },
   "outputs": [],
   "source": [
    "import textbrewer\n",
    "from textbrewer import GeneralDistiller\n",
    "from textbrewer import TrainingConfig, DistillationConfig\n",
    "from transformers import BertForTokenClassification, BertConfig, AdamW,BertTokenizer\n",
    "from transformers import get_linear_schedule_with_warmup\n",
    "import torch "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YloYaVoWs14s"
   },
   "source": [
    "Initialize the student model from a BertConfig and prepare the teacher model.\n",
    "\n",
    "`bert_config_L3.json` describes a 3-layer BERT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_P7a1W9NXdH9"
   },
   "outputs": [],
   "source": [
    "# Student: randomly initialised model built from the 3-layer example config.\n",
    "bert_config_T3 = BertConfig.from_json_file('/content/drive/MyDrive/TextBrewer-master/examples/student_config/bert_base_cased_config/bert_config_L3.json') \n",
    "bert_config_T3.output_hidden_states = True  # hidden states are needed by the intermediate hidden_mse matches\n",
    "bert_config_T3.num_labels = len(label_list)\n",
    "\n",
    "student_model = BertForTokenClassification(bert_config_T3)\n",
    "\n",
    "# Teacher: take the architecture from the checkpoint the teacher was fine-tuned from,\n",
    "# so the saved state_dict is guaranteed to fit. A config file written for a different\n",
    "# checkpoint (e.g. a different vocabulary size) would make load_state_dict fail.\n",
    "bert_config = BertConfig.from_pretrained(model_checkpoint)\n",
    "bert_config.output_hidden_states = True\n",
    "bert_config.num_labels = len(label_list)\n",
    "\n",
    "teacher_model = BertForTokenClassification(bert_config) \n",
    "# map_location lets weights saved during a GPU session load on a CPU-only runtime as well.\n",
    "teacher_model.load_state_dict(torch.load('/content/drive/MyDrive/msra_teacher_model.pt', map_location=device))\n",
    "\n",
    "teacher_model.to(device)\n",
    "student_model.to(device)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MKC_jnkuxTJn"
   },
   "source": [
    "The cell below distills the teacher model into the student model prepared above.\n",
    "\n",
    "After the cell finishes, the distilled checkpoints will be in the `saved_models` directory (TextBrewer's default `output_dir`) in the Colab file list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "C6K8OyXuXgsh"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "num_epochs = 20\n",
    "\n",
    "# Build the training batches for distillation (train_dataloader was previously used\n",
    "# without ever being defined). Columns the collator cannot turn into tensors are dropped,\n",
    "# and every collated batch is converted to a plain dict so that the distiller can pass\n",
    "# it to the models as keyword arguments.\n",
    "model_input_columns = [\"input_ids\", \"token_type_ids\", \"attention_mask\", \"labels\"]\n",
    "distill_train_dataset = tokenized_datasets[\"train\"].remove_columns(\n",
    "    [name for name in tokenized_datasets[\"train\"].column_names if name not in model_input_columns])\n",
    "train_dataloader = DataLoader(\n",
    "    distill_train_dataset,\n",
    "    shuffle=True,\n",
    "    batch_size=batch_size,\n",
    "    collate_fn=lambda features: dict(data_collator(features)))\n",
    "\n",
    "num_training_steps = len(train_dataloader) * num_epochs\n",
    "\n",
    "optimizer = AdamW(student_model.parameters(), lr=1e-5)\n",
    "\n",
    "# Linear warm-up over the first 10% of the steps, then linear decay.\n",
    "scheduler_class = get_linear_schedule_with_warmup\n",
    "\n",
    "scheduler_args = {'num_warmup_steps':int(0.1*num_training_steps), 'num_training_steps':num_training_steps}\n",
    "\n",
    "def simple_adaptor(batch, model_outputs):\n",
    "  \"\"\"Expose the logits and hidden states of a model output under the keys the distiller reads.\"\"\"\n",
    "  return {\"logits\":model_outputs.logits, 'hidden': model_outputs.hidden_states}\n",
    "\n",
    "# Match teacher hidden states 0/4/8/12 to student hidden states 0/1/2/3 with an MSE loss.\n",
    "distill_config = DistillationConfig(\n",
    "    intermediate_matches=[{\"layer_T\":0, \"layer_S\":0, \"feature\":\"hidden\", \"loss\":\"hidden_mse\", \"weight\":1},\n",
    "               {\"layer_T\":4, \"layer_S\":1, \"feature\":\"hidden\", \"loss\":\"hidden_mse\", \"weight\":1},\n",
    "               {\"layer_T\":8, \"layer_S\":2, \"feature\":\"hidden\", \"loss\":\"hidden_mse\", \"weight\":1},\n",
    "               {\"layer_T\":12,\"layer_S\":3, \"feature\":\"hidden\", \"loss\":\"hidden_mse\", \"weight\":1}])\n",
    "\n",
    "# Run distillation on the same device the models were moved to.\n",
    "train_config = TrainingConfig(device=device)\n",
    "distiller = GeneralDistiller(\n",
    "    train_config=train_config, distill_config=distill_config,\n",
    "    model_T=teacher_model, model_S=student_model, \n",
    "    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)\n",
    "\n",
    "\n",
    "with distiller:\n",
    "    distiller.train(optimizer, train_dataloader, num_epochs, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BmTWyy9dyuqG"
   },
   "source": [
    "Then evaluate the distilled model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WTU4sYE-fKN2"
   },
   "outputs": [],
   "source": [
    "# Rebuild the student architecture from its config file and load a distilled checkpoint into it.\n",
    "bert_config_T3 = BertConfig.from_json_file('/content/drive/MyDrive/data/bert_config/bert_config_L3.json')\n",
    "\n",
    "bert_config_T3.output_hidden_states = True\n",
    "bert_config_T3.num_labels = len(label_list)\n",
    "test_model = BertForTokenClassification(bert_config_T3)\n",
    "\n",
    "\n",
    "# map_location lets a checkpoint written during a GPU session load on a CPU-only runtime as well.\n",
    "test_model.load_state_dict(torch.load('/content/drive/MyDrive/model/gs2813.pkl', map_location=device))\n",
    "test_model.to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2pc2iuCAZmbh"
   },
   "outputs": [],
   "source": [
    "# Arguments for the evaluation-only Trainer built in the next cell.\n",
    "args = TrainingArguments(\n",
    "    output_dir=\"distill-test\",\n",
    "    do_train=False,\n",
    "    do_eval=True,\n",
    "    no_cuda=False,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    learning_rate=2e-5,\n",
    "    num_train_epochs=2,\n",
    "    weight_decay=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "m3VwmhxvZUDR"
   },
   "outputs": [],
   "source": [
    "# Trainer wrapping the distilled model, with the same data, collator and metrics as before.\n",
    "trainer = Trainer(\n",
    "    model=test_model,\n",
    "    args=args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"test\"],\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "32q6n0fGzjtX"
   },
   "outputs": [],
   "source": [
    "# Evaluate on the eval_dataset given to the Trainer; the metrics are shown as the cell output.\n",
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4hq_TkiP1-I6"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyNntqYlaAMJtt9+jRm0Faul",
   "collapsed_sections": [],
   "include_colab_link": true,
   "mount_file_id": "1SveTz-CxoL2Dt7xnBSssnhnUzKf6DEth",
   "name": "msra.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
